Nested numerical solving problems

Scientific knowledge often takes the form of specific relationships expressed by systems of equations.

If there is an analytic solution to the equation system, we can just include the solution in our statistical model like any other form of structural knowledge: easy! However, often we want to solve equations that are hard or impossible to solve analytically, but can be solved approximately using numerical methods.

This is tricky in the context of Hamiltonian Monte Carlo for two reasons:

  1. Computation: HMC requires many evaluations of the log probability density function and its gradients.
Important

At every evaluation, the sampler needs to solve the embedded equation system and find the gradients of the solution with respect to all model parameters.

  2. Extra source of error: how good an approximation is good enough?
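
Both difficulties can be seen in a toy problem. The sketch below is not what any particular sampler does internally; it assumes the simple equation \(dC/dt = r \cdot C\), chosen because its exact solution is known. Differentiating the equation with respect to \(r\) gives a second equation for the sensitivity \(s = \partial C / \partial r\), namely \(ds/dt = r \cdot s + C\), so the solution and its gradient can be integrated together. The function name `solve_with_sensitivity` and all the numbers are illustrative.

```python
import math


def solve_with_sensitivity(r, c0, t_end, n_steps):
    """Euler-integrate dC/dt = r*C together with its sensitivity s = dC/dr.

    The sensitivity obeys ds/dt = r*s + C with s(0) = 0.
    """
    h = t_end / n_steps
    c, s = c0, 0.0
    for _ in range(n_steps):
        # The right-hand side is evaluated with the old c and s before assignment
        c, s = c + h * r * c, s + h * (r * s + c)
    return c, s


r, c0, t_end = 0.5, 1.0, 2.0
exact_c = c0 * math.exp(r * t_end)
exact_s = t_end * exact_c  # derivative of c0 * exp(r * t) with respect to r

for n_steps in (10, 100, 1000):
    c, s = solve_with_sensitivity(r, c0, t_end, n_steps)
    print(n_steps, abs(c - exact_c), abs(s - exact_s))
```

The error in both the solution and its gradient shrinks as the number of steps grows, but every extra step costs computation, and that cost is paid at every evaluation of the log density.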

Reading: Timonen et al. (2022).

Example

We have some tubes containing a substrate \(S\) and some biomass \(C\) that we think approximately follow the Monod equation for microbial growth:

\[\begin{align*} \frac{dC}{dt} &= \frac{\mu_{max}\cdot S(t)}{K_{S} + S(t)}\cdot C(t) \\ \frac{dS}{dt} &= -\gamma \cdot \frac{\mu_{max}\cdot S(t)}{K_{S} + S(t)} \cdot C(t) \end{align*}\]

We measured \(C\) and \(S\) at different timepoints in some experiments and we want to estimate \(\mu_{max}\), \(K_{S}\) and \(\gamma\) for the different strains in the tubes.

You can read more about the Monod equation in Allen and Waclaw (2019).
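
A small derivation helps with interpreting \(\gamma\). It follows directly from the two equations above and is added here as a reading aid. Multiplying the first equation by \(\gamma\) and adding the second gives zero, so \(\gamma \cdot C + S\) stays constant over time:

\[\begin{align*} \frac{d}{dt}\left(\gamma \cdot C(t) + S(t)\right) &= \gamma \cdot \frac{dC}{dt} + \frac{dS}{dt} = 0 \\ S(t) &= S(0) - \gamma \cdot \left(C(t) - C(0)\right) \end{align*}\]

In other words, \(\gamma\) is the amount of substrate consumed per unit of biomass produced.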

What we know

\(\mu_{max}, K_S, \gamma, S, C\) are non-negative.

\(S(0)\) and \(C(0)\) vary a little by tube.

\(\mu_{max}, K_S, \gamma\) vary by strain.

Measurement noise is roughly proportional to measured quantity.

Statistical model

We use two regression models to describe the measurements:

\[\begin{align*} y_C &\sim LN(\ln{\hat{C}}, \sigma_{C}) \\ y_S &\sim LN(\ln{\hat{S}}, \sigma_{S}) \end{align*}\]
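
The lognormal choice encodes the earlier statement that measurement noise is roughly proportional to the measured quantity. The following sketch uses only Python's standard library; the function name `simulate_measurement` and the values are illustrative assumptions. It draws measurements for a small and a large true value from identically seeded generators, so the only difference between the two sets is the true value.

```python
import math
import random


def simulate_measurement(yhat, sigma, rng):
    """Draw one measurement y ~ LN(ln(yhat), sigma)."""
    return rng.lognormvariate(math.log(yhat), sigma)


sigma = 0.1
# Identical seeds give identical underlying normal draws for both true values
rng_small, rng_large = random.Random(1), random.Random(1)
small = [simulate_measurement(0.1, sigma, rng_small) for _ in range(5)]
large = [simulate_measurement(10.0, sigma, rng_large) for _ in range(5)]

for y_small, y_large in zip(small, large):
    # Absolute errors differ by a factor of 100; relative errors are the same
    print(
        f"abs. error {y_small - 0.1:+.4f} vs {y_large - 10.0:+.4f}, "
        f"ratio to truth {y_small / 0.1:.4f} vs {y_large / 10.0:.4f}"
    )
```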

To capture the variation in parameters by tube and strain we add a hierarchical regression model:

\[\begin{align*} \ln{\mu_{max}} &\sim N(a_{\mu_{max}}, \tau_{\mu_{max}}) \\ \ln{\gamma} &\sim N(a_{\gamma}, \tau_{\gamma}) \\ \ln{K_S} &\sim N(a_{K_S}, \tau_{K_S}) \end{align*}\]
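
The Stan program later on this page writes each of these lines in an equivalent non-centred form, using a standard normal variable (`ln_mu_max_z` in the code) together with the location `a_mu_max` and scale `t_mu_max`. For \(\mu_{max}\), with the symbol \(z_{\mu_{max}}\) introduced here purely for illustration:

\[\begin{align*} z_{\mu_{max}} &\sim N(0, 1) \\ \ln{\mu_{max}} &= a_{\mu_{max}} + \tau_{\mu_{max}} \cdot z_{\mu_{max}} \end{align*}\]

The same pattern applies to \(\gamma\) and \(K_S\).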

To get a true abundance given some parameters we put an ODE in the model:

\[ \hat{C}(t), \hat{S}(t) = \text{solve-monod-equation}(t, C_0, S_0, \mu_{max}, \gamma, K_S) \]
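
As a rough picture of what such a solve step does, here is a minimal pure-Python sketch using fixed-step classical Runge-Kutta (RK4). This is a swapped-in technique for illustration only: the Stan program on this page uses the adaptive `ode_bdf_tol` solver instead. The function names `monod_rhs` and `solve_monod` are hypothetical, and the parameter values are the exponentials of the location parameters used in the simulation further down.

```python
import math


def monod_rhs(state, mu_max, ks, gamma):
    """Right-hand side of the Monod system for state = (C, S)."""
    c, s = state
    mu = mu_max * s / (ks + s)
    return (mu * c, -gamma * mu * c)


def solve_monod(c0, s0, mu_max, ks, gamma, t_end, n_steps=1000):
    """Integrate from t = 0 to t_end with fixed-step RK4; returns (C, S) at t_end."""
    h = t_end / n_steps
    state = (c0, s0)
    for _ in range(n_steps):
        k1 = monod_rhs(state, mu_max, ks, gamma)
        k2 = monod_rhs([x + 0.5 * h * k for x, k in zip(state, k1)], mu_max, ks, gamma)
        k3 = monod_rhs([x + 0.5 * h * k for x, k in zip(state, k2)], mu_max, ks, gamma)
        k4 = monod_rhs([x + h * k for x, k in zip(state, k3)], mu_max, ks, gamma)
        state = tuple(
            x + h / 6 * (a + 2 * b + 2 * c + d)
            for x, a, b, c, d in zip(state, k1, k2, k3, k4)
        )
    return state


# Illustrative values on the same scale as the simulation study
c0, s0 = math.exp(-2.1), math.exp(0.2)
mu_max, ks, gamma = math.exp(-1.7), math.exp(-1.3), math.exp(-0.6)
c_end, s_end = solve_monod(c0, s0, mu_max, ks, gamma, t_end=15)
print(f"C(15) = {c_end:.3f}, S(15) = {s_end:.3f}")
```

Biomass grows while substrate is depleted, and the quantity \(\gamma \cdot C + S\) is the same at the end as at the start, which is a handy sanity check on any solver for this system.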

Imports

import itertools

import arviz as az
import cmdstanpy
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt

Specify true parameters

In order to avoid doing too much annoying handling of strings we assume that all the parts of the problem have meaningful 1-indexed integer labels: for example, species 1 is biomass.

This code specifies the dimensions of our problem.

N_strain = 4
N_tube = 16
N_timepoint = 20
duration = 15
strains = [i+1 for i in range(N_strain)]
tubes = [i+1 for i in range(N_tube)]
species = [1, 2]
measurement_timepoint_ixs = [4, 7, 12, 15, 17]
timepoints = pd.Series(
    np.linspace(0.01, duration, N_timepoint),
    name="time",
    index=range(1, N_timepoint+1)
)
SEED = 12345
rng = np.random.default_rng(seed=SEED)

This code defines some true values for the parameters - we will use these to generate fake data.

true_param_values = {
    "a_mu_max": -1.7,
    "a_ks": -1.3,
    "a_gamma": -0.6,
    "t_mu_max": 0.2,
    "t_ks": 0.3,
    "t_gamma": 0.13,
    # Draw from the seeded generator `rng` so that the fake data are reproducible
    "species_zero": [
        [np.exp(rng.normal(-2.1, 0.1)), np.exp(rng.normal(0.2, 0.1))]
        for _ in range(N_tube)
    ],
    "sigma_y": [0.08, 0.1],
    "ln_mu_max_z": rng.normal(0, 1, size=N_strain).tolist(),
    "ln_ks_z": rng.normal(0, 1, size=N_strain).tolist(),
    "ln_gamma_z": rng.normal(0, 1, size=N_strain).tolist(),
}
for var in ["mu_max", "ks", "gamma"]:
    true_param_values[var] = np.exp(
        true_param_values[f"a_{var}"]
        + true_param_values[f"t_{var}"] * np.array(true_param_values[f"ln_{var}_z"])
    ).tolist()

A bit of data transformation

This code does some handy transformations on the data using pandas, giving us a table of information about the measurements.

tube_to_strain = pd.Series(
    [
        (i % N_strain) + 1 for i in range(N_tube)  # % operator finds remainder
    ], index=tubes, name="strain"
)
measurements = (
    pd.DataFrame(
        itertools.product(tubes, measurement_timepoint_ixs, species),
        columns=["tube", "timepoint", "species"],
        index=range(1, len(tubes) * len(measurement_timepoint_ixs) * len(species) + 1)
    )
    .join(tube_to_strain, on="tube")
    .join(timepoints, on="timepoint")
)
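
To see what the remainder trick and the product produce without running pandas, here is a small standard-library sketch. The lower-case names are chosen here to avoid clashing with the notebook's variables; the sizes match the ones defined above.

```python
import itertools

n_strain, n_tube = 4, 16
measured_timepoints = [4, 7, 12, 15, 17]
species_labels = [1, 2]

# Tube labels are 1-indexed, so shift to 0-indexed before taking the remainder
strain_of_tube = {tube: ((tube - 1) % n_strain) + 1 for tube in range(1, n_tube + 1)}
print(strain_of_tube)

# One row per (tube, timepoint, species) combination, as in the DataFrame above
rows = list(itertools.product(strain_of_tube, measured_timepoints, species_labels))
print(len(rows), rows[:3])
```

Strains are assigned to tubes in rotation, so each strain ends up in the same number of tubes.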

Generating a Stan input dictionary

This code puts the data in the correct format for cmdstanpy.

stan_input_structure = {
    "N_measurement": len(measurements),
    "N_timepoint": N_timepoint,
    "N_tube": N_tube,
    "N_strain": N_strain,
    "tube": measurements["tube"].values.tolist(),
    "measurement_timepoint": measurements["timepoint"].values.tolist(),
    "measured_species": measurements["species"].values.tolist(),
    "strain": tube_to_strain.values.tolist(),
    "timepoint_time": timepoints.values.tolist(),
}

This code defines some prior distributions for the model’s parameters:

priors = {
    # parameters that can be negative:
    "prior_a_mu_max": [-1.8, 0.2],
    "prior_a_ks": [-1.3, 0.1],
    "prior_a_gamma": [-0.5, 0.1],
    # parameters that are non-negative:
    "prior_t_mu_max": [-1.4, 0.1],
    "prior_t_ks": [-1.2, 0.1],
    "prior_t_gamma": [-2, 0.1],
    "prior_species_zero": [[[-2.1, 0.1], [0.2, 0.1]]] * N_tube,
    "prior_sigma_y": [[-2.3, 0.15], [-2.3, 0.15]],
}
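
The pairs for `prior_species_zero` and `prior_sigma_y` are the location and scale of lognormal distributions in the Stan program below, so they are easier to judge after converting to the natural scale. A minimal sketch, assuming a 98% central interval is a useful summary; the helper `lognormal_interval` is hypothetical.

```python
import math
from statistics import NormalDist


def lognormal_interval(mu, sigma, prob=0.98):
    """Lower bound, median and upper bound of LN(mu, sigma) on the natural scale."""
    z = NormalDist().inv_cdf(1 - (1 - prob) / 2)
    return math.exp(mu - z * sigma), math.exp(mu), math.exp(mu + z * sigma)


lognormal_priors = {
    "sigma_y": (-2.3, 0.15),
    "biomass at t=0": (-2.1, 0.1),
    "substrate at t=0": (0.2, 0.1),
}
for name, (mu, sigma) in lognormal_priors.items():
    low, median, high = lognormal_interval(mu, sigma)
    print(f"{name}: median {median:.3f}, 98% interval [{low:.3f}, {high:.3f}]")
```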

The next bit of code lets us configure Stan’s interface to the Sundials ODE solver.

ode_solver_configuration = {
    "abs_tol": 1e-7,
    "rel_tol": 1e-7,
    "max_num_steps": int(1e7)
}
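
As a rough guide to what these tolerances mean, SUNDIALS-style solvers combine them into a per-component error scale of about `rel_tol * |y| + abs_tol`. The sketch below is a simplified reading, not the solver's actual code: the solver controls a weighted norm of its estimated local error, which is not the same as the error of the final solution. The function name `tolerated_error` is hypothetical.

```python
def tolerated_error(value, rel_tol=1e-7, abs_tol=1e-7):
    """Approximate per-component error scale: rel_tol * |value| + abs_tol."""
    return rel_tol * abs(value) + abs_tol


for value in (0.0, 0.1, 1.0, 100.0):
    print(f"|y| = {value:g}: error scale about {tolerated_error(value):.3g}")
```

The absolute tolerance dominates when a state is close to zero, and the relative tolerance dominates when it is large.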

Now we can put all the inputs together:

stan_input_common = stan_input_structure | priors | ode_solver_configuration
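
The `|` operator merges dictionaries and needs Python 3.9 or later. When a key appears on both sides, the right-hand value wins, which is what later lets each run add its own `y` and `likelihood` entries on top of the common inputs. A tiny illustration with made-up dictionaries:

```python
structure = {"N_tube": 16, "N_strain": 4}
solver = {"abs_tol": 1e-7, "rel_tol": 1e-7}
override = {"abs_tol": 1e-9}

# Later operands take precedence; the originals are left unchanged
merged = structure | solver | override
print(merged)
print(solver)
```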

Load the model

This code loads the Stan program at monod.stan as a CmdStanModel object and compiles it using cmdstan’s compiler.

model = cmdstanpy.CmdStanModel(stan_file="../src/stan/monod.stan")
print(model.code())
functions {
  real get_mu_at_t(real mu_max, real ks, real S_at_t) {
    return (mu_max * S_at_t) / (ks + S_at_t);
  }
  vector ddt(real t, vector species, real mu_max, real ks, real gamma) {
    real mu_at_t = get_mu_at_t(mu_max, ks, species[2]);
    vector[2] out;
    out[1] = mu_at_t * species[1];
    out[2] = -gamma * mu_at_t * species[1];
    return out;
  }
}
data {
  int<lower=1> N_measurement;
  int<lower=1> N_timepoint;
  int<lower=1> N_tube;
  int<lower=1> N_strain;
  array[N_measurement] int<lower=1, upper=N_tube> tube;
  array[N_measurement] int<lower=1, upper=N_timepoint> measurement_timepoint;
  array[N_measurement] int<lower=1, upper=2> measured_species;
  vector<lower=0>[N_measurement] y;
  array[N_tube] int<lower=1, upper=N_strain> strain;
  array[N_timepoint] real<lower=0> timepoint_time;
  array[N_tube, 2] vector[2] prior_species_zero;
  array[2] vector[2] prior_sigma_y;
  vector[2] prior_a_mu_max;
  vector[2] prior_a_ks;
  vector[2] prior_a_gamma;
  vector[2] prior_t_mu_max;
  vector[2] prior_t_gamma;
  vector[2] prior_t_ks;
  real<lower=0> abs_tol;
  real<lower=0> rel_tol;
  int<lower=1> max_num_steps;
  int<lower=0, upper=1> likelihood;
}
parameters {
  vector[N_strain] ln_mu_max_z;
  vector[N_strain] ln_ks_z;
  vector[N_strain] ln_gamma_z;
  real a_mu_max;
  real a_ks;
  real a_gamma;
  real<lower=0> t_mu_max;
  real<lower=0> t_ks;
  real<lower=0> t_gamma;
  array[N_tube] vector<lower=0>[2] species_zero;
  vector<lower=0>[2] sigma_y;
}
transformed parameters {
  vector[N_strain] mu_max = exp(a_mu_max + ln_mu_max_z * t_mu_max);
  vector[N_strain] ks = exp(a_ks + ln_ks_z * t_ks);
  vector[N_strain] gamma = exp(a_gamma + ln_gamma_z * t_gamma);
  array[N_tube, N_timepoint] vector[2] abundance;
  for (tube_t in 1 : N_tube) {
    abundance[tube_t] = ode_bdf_tol(ddt, species_zero[tube_t], 0,
                                    timepoint_time,
                                    abs_tol, rel_tol, max_num_steps,
                                    mu_max[strain[tube_t]],
                                    ks[strain[tube_t]], gamma[strain[tube_t]]);
  }
}
model {
  // priors
  ln_mu_max_z ~ std_normal();
  ln_ks_z ~ std_normal();
  ln_gamma_z ~ std_normal();
  a_mu_max ~ normal(prior_a_mu_max[1], prior_a_mu_max[2]);
  a_ks ~ normal(prior_a_ks[1], prior_a_ks[2]);
  a_gamma ~ normal(prior_a_gamma[1], prior_a_gamma[2]);
  // the t_* parameters are positive and their priors are given on the log scale
  t_mu_max ~ lognormal(prior_t_mu_max[1], prior_t_mu_max[2]);
  t_ks ~ lognormal(prior_t_ks[1], prior_t_ks[2]);
  t_gamma ~ lognormal(prior_t_gamma[1], prior_t_gamma[2]);
  for (s in 1 : 2) {
    sigma_y[s] ~ lognormal(prior_sigma_y[s, 1], prior_sigma_y[s, 2]);
    for (t in 1 : N_tube){
      species_zero[t, s] ~ lognormal(prior_species_zero[t, s, 1],
                                     prior_species_zero[t, s, 2]);
    }
  }
  // likelihood
  if (likelihood) {
    for (m in 1 : N_measurement) {
      real yhat = abundance[tube[m], measurement_timepoint[m], measured_species[m]];
      y[m] ~ lognormal(log(yhat), sigma_y[measured_species[m]]);
    }
  }
}
generated quantities {
  vector[N_measurement] yrep;
  vector[N_measurement] llik;
  for (m in 1 : N_measurement){
    real yhat = abundance[tube[m], measurement_timepoint[m], measured_species[m]];
    yrep[m] = lognormal_rng(log(yhat), sigma_y[measured_species[m]]);
    llik[m] = lognormal_lpdf(y[m] | log(yhat), sigma_y[measured_species[m]]);
  }
}
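
The `llik` values in the generated quantities block are pointwise lognormal log densities. For readers who want to check one by hand, here is a minimal standard-library sketch of that density, written from the textbook formula rather than copied from Stan's source; the example numbers are made up.

```python
import math


def lognormal_lpdf(y, mu, sigma):
    """Log density of LN(mu, sigma) at y > 0."""
    log_y = math.log(y)
    return (
        -log_y
        - math.log(sigma)
        - 0.5 * math.log(2 * math.pi)
        - 0.5 * ((log_y - mu) / sigma) ** 2
    )


yhat, sigma = 0.5, 0.1
for y in (0.4, 0.5, 0.6):
    print(y, lognormal_lpdf(y, math.log(yhat), sigma))
```

The `-log_y` term is what distinguishes this from a normal density evaluated at `log(y)`: it accounts for the change of variables from the log scale back to the measurement scale.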

Sample in fixed param mode to generate fake data

stan_input_true = stan_input_common | {
    "y": np.ones(len(measurements)).tolist(),  # dummy values as we don't need measurements yet
    "likelihood": 0                            # we don't need to evaluate the likelihood
}
coords = {
    "strain": strains,
    "tube": tubes,
    "species": species,
    "timepoint": timepoints.index.values,
    "measurement": measurements.index.values
}
dims = {
    "abundance": ["tube", "timepoint", "species"],
    "mu_max": ["strain"],
    "ks": ["strain"],
    "gamma": ["strain"],
    "species_zero": ["tube", "species"],
    "y": ["measurement"],
    "yrep": ["measurement"],
    "llik": ["measurement"]
}

mcmc_true = model.sample(
    data=stan_input_true,
    iter_sampling=1,
    fixed_param=True,
    chains=1,
    refresh=1,
    inits=true_param_values,
    seed=SEED,
)
idata_true = az.from_cmdstanpy(
    mcmc_true,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
15:13:46 - cmdstanpy - INFO - CmdStan start processing
15:13:46 - cmdstanpy - INFO - CmdStan done processing.
                                                                                

Look at results

def plot_sim(true_abundance, fake_measurements, species_to_ax):
    f, axes = plt.subplots(1, 2, figsize=[9, 3])

    axes[species_to_ax[1]].set_title("Species 1")
    axes[species_to_ax[2]].set_title("Species 2")
    for ax in axes:
        ax.set_xlabel("Time")
        ax.set_ylabel("Abundance")
    # Plot each tube's trajectory once, outside the axis-labelling loop
    for (tube_i, species_i), df_i in true_abundance.groupby(["tube", "species"]):
        ax = axes[species_to_ax[species_i]]
        fm = df_i.merge(
            fake_measurements.drop("time", axis=1),
            on=["tube", "species", "timepoint"]
        )
        ax.plot(
            df_i.set_index("time")["abundance"], color="black", linewidth=0.5
        )
        ax.scatter(
            fm["time"],
            fm["simulated_measurement"],
            color="r",
            marker="x",
            label="simulated measurement"
        )
    return f, axes

species_to_ax = {1: 0, 2: 1}
true_abundance = (
    idata_true.posterior["abundance"]
    .to_dataframe()
    .droplevel(["chain", "draw"])
    .join(timepoints, on="timepoint")
    .reset_index()
)
fake_measurements = measurements.join(
    idata_true.posterior_predictive["yrep"]
    .to_series()
    .droplevel(["chain", "draw"])
    .rename("simulated_measurement")
).copy()
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)

f.savefig("img/monod_simulated_data.png")

Sample in prior mode

stan_input_prior = stan_input_common | {
    "y": fake_measurements["simulated_measurement"],
    "likelihood": 0
}
mcmc_prior = model.sample(
    data=stan_input_prior,
    iter_warmup=100,
    iter_sampling=100,
    chains=1,
    refresh=1,
    save_warmup=True,
    inits=true_param_values,
    seed=SEED,
)
idata_prior = az.from_cmdstanpy(
    mcmc_prior,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
idata_prior
15:13:46 - cmdstanpy - INFO - CmdStan start processing
15:14:17 - cmdstanpy - INFO - CmdStan done processing.
15:14:17 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: ode_bdf_tol: initial state[2] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Consider re-running with show_console=True if the above output is unclear!
                                                                                
arviz.InferenceData
    • <xarray.Dataset> Size: 564kB
      Dimensions:            (chain: 1, draw: 100, ln_mu_max_z_dim_0: 4,
                              ln_ks_z_dim_0: 4, ln_gamma_z_dim_0: 4, tube: 16,
                              species: 2, sigma_y_dim_0: 2, strain: 4, timepoint: 20)
      Coordinates:
        * chain              (chain) int64 8B 0
        * draw               (draw) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99
        * ln_mu_max_z_dim_0  (ln_mu_max_z_dim_0) int64 32B 0 1 2 3
        * ln_ks_z_dim_0      (ln_ks_z_dim_0) int64 32B 0 1 2 3
        * ln_gamma_z_dim_0   (ln_gamma_z_dim_0) int64 32B 0 1 2 3
        * tube               (tube) int64 128B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        * species            (species) int64 16B 1 2
        * sigma_y_dim_0      (sigma_y_dim_0) int64 16B 0 1
        * strain             (strain) int64 32B 1 2 3 4
        * timepoint          (timepoint) int64 160B 1 2 3 4 5 6 ... 15 16 17 18 19 20
      Data variables: (12/15)
          ln_mu_max_z        (chain, draw, ln_mu_max_z_dim_0) float64 3kB -0.2028 ....
          ln_ks_z            (chain, draw, ln_ks_z_dim_0) float64 3kB 0.5768 ... -1.08
          ln_gamma_z         (chain, draw, ln_gamma_z_dim_0) float64 3kB 1.083 ... ...
          a_mu_max           (chain, draw) float64 800B -1.79 -1.844 ... -2.034 -1.567
          a_ks               (chain, draw) float64 800B -1.339 -1.234 ... -1.431
          a_gamma            (chain, draw) float64 800B -0.4577 -0.7324 ... -0.3797
          ...                 ...
          species_zero       (chain, draw, tube, species) float64 26kB 0.1268 ... 1...
          sigma_y            (chain, draw, sigma_y_dim_0) float64 2kB 0.105 ... 0.0...
          mu_max             (chain, draw, strain) float64 3kB 0.1667 ... 0.2098
          ks                 (chain, draw, strain) float64 3kB 0.2642 ... 0.2383
          gamma              (chain, draw, strain) float64 3kB 0.6363 ... 0.6785
          abundance          (chain, draw, tube, timepoint, species) float64 512kB ...
      Attributes:
          created_at:                 2024-04-26T13:14:17.359859
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 130kB
      Dimensions:      (chain: 1, draw: 100, measurement: 160)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          yrep         (chain, draw, measurement) float64 128kB 0.2099 ... 0.6809
      Attributes:
          created_at:                 2024-04-26T13:14:17.365163
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 130kB
      Dimensions:      (chain: 1, draw: 100, measurement: 160)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          llik         (chain, draw, measurement) float64 128kB 3.016 0.8591 ... 1.201
      Attributes:
          created_at:                 2024-04-26T13:14:17.366055
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 11kB
      Dimensions:          (chain: 1, draw: 200)
      Coordinates:
        * chain            (chain) int64 8B 0
        * draw             (draw) int64 2kB 0 1 2 3 4 5 6 ... 194 195 196 197 198 199
      Data variables:
          lp               (chain, draw) float64 2kB -470.8 -470.8 ... -440.7 -439.6
          acceptance_rate  (chain, draw) float64 2kB 0.8477 0.0 0.0 ... 0.7256 0.6523
          step_size        (chain, draw) float64 2kB 0.0625 10.91 ... 0.06444 0.06444
          tree_depth       (chain, draw) int64 2kB 3 0 0 4 8 7 7 6 ... 6 6 6 6 6 6 6 6
          n_steps          (chain, draw) int64 2kB 7 1 1 15 255 127 ... 63 63 63 63 63
          diverging        (chain, draw) bool 200B False True True ... False False
          energy           (chain, draw) float64 2kB 564.6 498.1 495.0 ... 467.2 469.0
      Attributes:
          created_at:                 2024-04-26T13:14:17.363711
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

We can draw a few trajectories of the true abundance from the prior and overlay them on the plot of simulated data.

prior_abundances = (
    idata_prior
    .posterior["abundance"]
    .to_dataframe()
    .reset_index()
    .join(timepoints, on="timepoint")
)
n_sample = 10  # number of (chain, draw) pairs to overlay
chains_and_draws = rng.choice(prior_abundances[["chain", "draw"]].values, n_sample)
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)
for tube_i in tubes:
    for species_j in species:
        abundance_sample = (
            prior_abundances.loc[
                lambda df: (df["tube"] == tube_i) & (df["species"] == species_j)
            ]
            .set_index(["chain", "draw"])
            .loc[chains_and_draws.tolist()]
            .reset_index()
        )
        axes[species_to_ax[species_j]].plot(
            abundance_sample.set_index("time")["abundance"],
            alpha=0.5, color="skyblue", zorder=-1
        )
f.savefig("img/monod_priors.png")

Sample in posterior mode

stan_input_posterior = stan_input_common | {
    "y": fake_measurements["simulated_measurement"],
    "likelihood": 1
}
mcmc_posterior = model.sample(
    data=stan_input_posterior,
    iter_warmup=300,
    iter_sampling=300,
    chains=4,
    refresh=1,
    inits=true_param_values,
    seed=SEED,
)
idata_posterior = az.from_cmdstanpy(
    mcmc_posterior,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
idata_posterior
15:14:18 - cmdstanpy - INFO - CmdStan start processing
15:19:25 - cmdstanpy - INFO - CmdStan done processing.
15:19:25 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4: 
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: ode_bdf_tol: initial state[2] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: ode_bdf_tol: initial state[1] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -1: 
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Consider re-running with show_console=True if the above output is unclear!
                                                                                                                                                                                                                                                                                                                                
arviz.InferenceData
    • <xarray.Dataset> Size: 7MB
      Dimensions:            (chain: 4, draw: 300, ln_mu_max_z_dim_0: 4,
                              ln_ks_z_dim_0: 4, ln_gamma_z_dim_0: 4, tube: 16,
                              species: 2, sigma_y_dim_0: 2, strain: 4, timepoint: 20)
      Coordinates:
        * chain              (chain) int64 32B 0 1 2 3
        * draw               (draw) int64 2kB 0 1 2 3 4 5 ... 294 295 296 297 298 299
        * ln_mu_max_z_dim_0  (ln_mu_max_z_dim_0) int64 32B 0 1 2 3
        * ln_ks_z_dim_0      (ln_ks_z_dim_0) int64 32B 0 1 2 3
        * ln_gamma_z_dim_0   (ln_gamma_z_dim_0) int64 32B 0 1 2 3
        * tube               (tube) int64 128B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        * species            (species) int64 16B 1 2
        * sigma_y_dim_0      (sigma_y_dim_0) int64 16B 0 1
        * strain             (strain) int64 32B 1 2 3 4
        * timepoint          (timepoint) int64 160B 1 2 3 4 5 6 ... 15 16 17 18 19 20
      Data variables: (12/15)
          ln_mu_max_z        (chain, draw, ln_mu_max_z_dim_0) float64 38kB -0.2781 ...
          ln_ks_z            (chain, draw, ln_ks_z_dim_0) float64 38kB 0.3355 ... 0...
          ln_gamma_z         (chain, draw, ln_gamma_z_dim_0) float64 38kB -0.829 .....
          a_mu_max           (chain, draw) float64 10kB -1.762 -1.741 ... -1.713
          a_ks               (chain, draw) float64 10kB -1.436 -1.297 ... -1.39 -1.376
          a_gamma            (chain, draw) float64 10kB -0.5375 -0.5353 ... -0.6106
          ...                 ...
          species_zero       (chain, draw, tube, species) float64 307kB 0.13 ... 1.181
          sigma_y            (chain, draw, sigma_y_dim_0) float64 19kB 0.09076 ... ...
          mu_max             (chain, draw, strain) float64 38kB 0.1666 ... 0.2052
          ks                 (chain, draw, strain) float64 38kB 0.2382 ... 0.2543
          gamma              (chain, draw, strain) float64 38kB 0.5818 ... 0.5421
          abundance          (chain, draw, tube, timepoint, species) float64 6MB 0....
      Attributes:
          created_at:                 2024-04-26T13:19:26.043680
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 300, measurement: 160)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 2kB 0 1 2 3 4 5 6 ... 293 294 295 296 297 298 299
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          yrep         (chain, draw, measurement) float64 2MB 0.1833 1.129 ... 0.6981
      Attributes:
          created_at:                 2024-04-26T13:19:26.049491
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 300, measurement: 160)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 2kB 0 1 2 3 4 5 6 ... 293 294 295 296 297 298 299
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          llik         (chain, draw, measurement) float64 2MB 3.186 0.5433 ... 1.375
      Attributes:
          created_at:                 2024-04-26T13:19:26.050691
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 61kB
      Dimensions:          (chain: 4, draw: 300)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 2kB 0 1 2 3 4 5 6 ... 294 295 296 297 298 299
      Data variables:
          lp               (chain, draw) float64 10kB -317.8 -311.0 ... -330.2 -332.7
          acceptance_rate  (chain, draw) float64 10kB 0.9637 0.7829 ... 0.9201 0.9813
          step_size        (chain, draw) float64 10kB 0.06185 0.06185 ... 0.06799
          tree_depth       (chain, draw) int64 10kB 6 6 6 6 6 6 6 6 ... 6 6 6 6 6 6 6
          n_steps          (chain, draw) int64 10kB 63 63 63 63 63 ... 63 63 63 63 63
          diverging        (chain, draw) bool 1kB False False False ... False False
          energy           (chain, draw) float64 10kB 335.8 336.9 ... 347.8 367.7
      Attributes:
          created_at:                 2024-04-26T13:19:26.047799
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

Diagnostics: is the posterior ok?

First check the sample_stats group to see if there were any divergent transitions and if the lp parameter converged.

az.summary(idata_posterior.sample_stats)
/Users/tedgro/repos/biosustain/bayesian_statistics_for_computational_biology/.venv/lib/python3.12/site-packages/arviz/stats/diagnostics.py:592: RuntimeWarning: invalid value encountered in scalar divide
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
                     mean      sd    hdi_3%   hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail         r_hat
lp               -318.452   5.296  -327.970  -308.436      0.236    0.167     517.0     712.0  1.000000e+00
acceptance_rate     0.928   0.096     0.733     1.000      0.003    0.002    1145.0    1260.0  1.010000e+00
step_size           0.057   0.009     0.044     0.068      0.004    0.003       4.0       4.0  5.859337e+15
tree_depth          6.131   0.337     6.000     7.000      0.110    0.080       9.0       9.0  1.380000e+00
n_steps            75.053  25.033    63.000   127.000      9.606    7.108       7.0       7.0  1.670000e+00
diverging           0.000   0.000     0.000     0.000      0.000    0.000    1200.0    1200.0           NaN
energy            344.577   7.319   330.529   357.918      0.346    0.245     452.0     646.0  1.010000e+00
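
The enormous `r_hat` for `step_size` and the `NaN` for `diverging` are not signs of trouble in themselves. Both quantities barely vary, or do not vary at all, within a chain, so the within-chain variance in the denominator of the statistic is essentially zero. The sketch below uses the classic Gelman-Rubin formula to show the effect. ArviZ computes a rank-normalised split version, so this illustrates the idea rather than its implementation, and the chains are made up.

```python
import math
import random
from statistics import mean, variance


def basic_rhat(chains):
    """Classic (non-split, non-rank-normalised) potential scale reduction factor."""
    n = len(chains[0])
    within = mean(variance(chain) for chain in chains)
    between = n * variance([mean(chain) for chain in chains])
    if within == 0:
        # No within-chain variation: the ratio is 0/0 or x/0
        return math.nan if between == 0 else math.inf
    pooled = (n - 1) / n * within + between / n
    return math.sqrt(pooled / within)


rng = random.Random(0)
examples = {
    "well mixed": [[rng.gauss(0, 1) for _ in range(200)] for _ in range(4)],
    "constant within each chain": [[v] * 200 for v in (0.062, 0.068, 0.044, 0.057)],
    "constant everywhere": [[0.0] * 200 for _ in range(4)],
}
for name, chains in examples.items():
    print(name, basic_rhat(chains))
```

For this table, the rows that matter are `lp` and `energy`, together with the count of divergent transitions.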

Next check the parameter-by-parameter summary:

az.summary(idata_posterior)
                        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
ln_mu_max_z[0]        -1.421  0.576  -2.497   -0.378      0.024    0.017     564.0     798.0   1.01
ln_mu_max_z[1]        -1.901  0.588  -3.024   -0.801      0.024    0.017     604.0     800.0   1.00
ln_mu_max_z[2]         3.417  0.652   2.182    4.555      0.025    0.018     687.0     652.0   1.00
ln_mu_max_z[3]         0.292  0.526  -0.686    1.231      0.024    0.017     495.0     576.0   1.00
ln_ks_z[0]             0.024  1.061  -1.888    2.174      0.029    0.034    1396.0     705.0   1.00
...                      ...    ...     ...      ...        ...      ...       ...       ...    ...
abundance[16, 18, 2]   0.549  0.037   0.483    0.619      0.001    0.001    1927.0     929.0   1.00
abundance[16, 19, 1]   1.143  0.047   1.057    1.232      0.001    0.001    2007.0     927.0   1.00
abundance[16, 19, 2]   0.484  0.038   0.404    0.544      0.001    0.001    1779.0     930.0   1.00
abundance[16, 20, 1]   1.260  0.053   1.161    1.358      0.001    0.001    1881.0     953.0   1.00
abundance[16, 20, 2]   0.415  0.038   0.345    0.483      0.001    0.001    1630.0     983.0   1.00

704 rows × 9 columns

Show posterior intervals

posterior_abundances = (
    idata_posterior
    .posterior["abundance"]
    .to_dataframe()
    .reset_index()
    .join(timepoints, on="timepoint")
)
n_sample = 10  # number of (chain, draw) pairs to overlay
chains_and_draws = rng.choice(
    posterior_abundances[["chain", "draw"]].values, n_sample
)
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)
for tube_i in tubes:
    for species_j in species:
        abundance_sample = (
            posterior_abundances.loc[
                lambda df: (df["tube"] == tube_i) & (df["species"] == species_j)
            ]
            .set_index(["chain", "draw"])
            .loc[chains_and_draws.tolist()]
            .reset_index()
        )
        axes[species_to_ax[species_j]].plot(
            abundance_sample.set_index("time")["abundance"],
            alpha=0.5, color="skyblue", zorder=-1
        )
f.savefig("img/monod_posteriors.png")

Look at the posterior

The next few cells use arviz’s plot_posterior function to plot the marginal posterior distributions for some of the model’s parameters:

f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["gamma"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["gamma"]):
    ax.axvline(true_value, color="red")

f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["mu_max"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["mu_max"]):
    ax.axvline(true_value, color="red")

f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["ks"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["ks"]):
    ax.axvline(true_value, color="red")

References

Allen, Rosalind J, and Bartłomiej Waclaw. 2019. “Bacterial Growth: A Statistical Physicist’s Guide.” Reports on Progress in Physics. Physical Society (Great Britain) 82 (1): 016601. https://doi.org/10.1088/1361-6633/aae546.
Timonen, Juho, Nikolas Siccha, Ben Bales, Harri Lähdesmäki, and Aki Vehtari. 2022. “An Importance Sampling Approach for Reliable and Efficient Inference in Bayesian Ordinary Differential Equation Models.” arXiv. https://doi.org/10.48550/arXiv.2205.09059.